#ifndef _ZQ_CPP_LIB_GRADIENT_COMPRESSION_BODY_TERNGRAD_BODY_H
#define _ZQ_CPP_LIB_GRADIENT_COMPRESSION_BODY_TERNGRAD_BODY_H
#include <stdint.h>
#include <cstdio>                              //printf
#include <cstring>                             //memcpy
#include <chrono>
#include <thrust/random.h>
#include <thrust/sort.h>                       //sort()
#include <thrust/execution_policy.h>           //thrust::device
#include <thrust/functional.h>                 //greater<float>
#include <thrust/copy.h>                       //copy_if
#include <thrust/iterator/counting_iterator.h> // counting_iterator
#include <thrust/transform.h>                  //transform
#include "../naive_random.hpp"
#include "../operate_memory/get_policy_general.h"
#include "../likely.h"

namespace zq_cpp_lib
{
namespace gradient_compression_body
{
using namespace zq_cpp_lib::operate_memory;


/*
 * Packing functor for deterministic (round-to-nearest) quantization.
 *
 * Invoked with a byte index `i`. Reads the 2^data_per_byte_lg2 consecutive
 * floats starting at in_float[i << data_per_byte_lg2]. Each float v is mapped
 * to level nearbyint((v - min_val) * gap_inverse), and the k-th level is OR-ed
 * into bits [bitwidth*k, bitwidth*(k+1)) of the returned byte.
 * NOTE(review): assumes min_val/gap_inverse were chosen so every level fits in
 * `bitwidth` bits (terngrad_body derives them from the input's min/max).
 */
struct compress_without_random{
  float* in_float;
  uint8_t bitwidth;
  uint8_t data_per_byte_lg2;
  float min_val;
  float gap_inverse;

  compress_without_random(
    float* src,
    uint8_t bits,
    uint8_t per_byte_lg2,
    float lo,
    float inv_gap
  ) : in_float(src),
      bitwidth(bits),
      data_per_byte_lg2(per_byte_lg2),
      min_val(lo),
      gap_inverse(inv_gap) {}

  __host__ __device__
  uint8_t operator()(const int32_t& byte_idx){
    const int values_per_byte = 1 << data_per_byte_lg2;
    const float* group = in_float + (byte_idx << data_per_byte_lg2);
    uint8_t packed = 0;
#pragma unroll
    for (int k = 0; k < values_per_byte; ++k){
      const float scaled = (group[k] - min_val) * gap_inverse;
      const uint8_t level = nearbyint(scaled);
      packed |= (level << (bitwidth * k));
    }
    return packed;
  }
};


/*
 * Packing functor for stochastic quantization.
 *
 * Invoked with a byte index `i`. It seeds a zq_cpp_lib::naive_real_random<float>
 * generator (range 0.0..1.0) with timestamp + i. Then, for each of the
 * 2^data_per_byte_lg2 floats starting at in_float[i << data_per_byte_lg2], it
 * adds one draw to (v - min_val) * gap_inverse, truncates the result to uint8_t,
 * and ORs the k-th level into bits [bitwidth*k, bitwidth*(k+1)) of the result.
 * NOTE(review): the random stream is a pure function of (timestamp, i), so
 * repeated calls with the same timestamp reproduce the same rounding.
 */
struct compress_with_random{
  float* in_float;
  uint8_t bitwidth;
  uint8_t data_per_byte_lg2;
  float min_val;
  float gap_inverse;
  unsigned long long timestamp;

  compress_with_random(
    float* src,
    uint8_t bits,
    uint8_t per_byte_lg2,
    float lo,
    float inv_gap,
    unsigned long long seed
  ) : in_float(src),
      bitwidth(bits),
      data_per_byte_lg2(per_byte_lg2),
      min_val(lo),
      gap_inverse(inv_gap),
      timestamp(seed) {}

  __host__ __device__
  uint8_t operator()(const int32_t& byte_idx){
    zq_cpp_lib::naive_real_random<float> uniform(0.0,1.0);
    uniform.srand(timestamp+byte_idx);
    const int values_per_byte = 1 << data_per_byte_lg2;
    const float* group = in_float + (byte_idx << data_per_byte_lg2);
    uint8_t packed = 0;
#pragma unroll
    for (int k = 0; k < values_per_byte; ++k){
      float scaled = (group[k] - min_val) * gap_inverse;
      scaled += uniform();
      const uint8_t level = static_cast<uint8_t>(scaled);
      packed |= (level << (bitwidth * k));
    }
    return packed;
  }
};

/*
 * Unpacking functor that produces dequantized values (overwrite mode).
 *
 * Invoked with an element index `i`. It reads packed byte
 * in_uint8_t[10 + (i >> data_per_byte_lg2)], skipping the 10-byte header, and
 * extracts the bitwidth-bit level stored in slot (i mod 2^data_per_byte_lg2).
 * It returns level * gap + min_f.
 */
struct decompress_write_to{
  uint8_t* in_uint8_t;
  uint8_t bitwidth;
  uint8_t data_per_byte_lg2;
  float min_f;
  float gap;

  decompress_write_to(
    uint8_t* packed_src,
    uint8_t bits,
    uint8_t per_byte_lg2,
    float lo,
    float step
  ) : in_uint8_t(packed_src),
      bitwidth(bits),
      data_per_byte_lg2(per_byte_lg2),
      min_f(lo),
      gap(step) {}

  __host__ __device__
  float operator()(const int32_t& elem_idx){
    const uint8_t value_mask = (1 << bitwidth) - 1;
    const uint8_t slot = elem_idx & ((1 << data_per_byte_lg2) - 1);
    const uint8_t packed = in_uint8_t[10 + (elem_idx >> data_per_byte_lg2)];
    const uint8_t level = (packed >> (slot * bitwidth)) & value_mask;
    return static_cast<float>(level*gap + min_f);
  }
};
/*
 * Unpacking functor that accumulates dequantized values (add mode).
 *
 * Invoked with an element index `i`. It decodes the bitwidth-bit level stored
 * in slot (i mod 2^data_per_byte_lg2) of packed byte
 * in_uint8_t[10 + (i >> data_per_byte_lg2)], skipping the 10-byte header. It
 * returns level * gap + min_f + out_float[i]. The functor only reads
 * out_float; terngrad_r_body stores the returned value back into it.
 */
struct decompress_add_to{
  uint8_t* in_uint8_t;
  float* out_float;
  uint8_t bitwidth;
  uint8_t data_per_byte_lg2;
  float min_f;
  float gap;

  decompress_add_to(
    uint8_t* packed_src,
    float* accum,
    uint8_t bits,
    uint8_t per_byte_lg2,
    float lo,
    float step
  ) : in_uint8_t(packed_src),
      out_float(accum),
      bitwidth(bits),
      data_per_byte_lg2(per_byte_lg2),
      min_f(lo),
      gap(step) {}

  __host__ __device__
  float operator()(const int32_t& elem_idx){
    const uint8_t value_mask = (1 << bitwidth) - 1;
    const uint8_t slot = elem_idx & ((1 << data_per_byte_lg2) - 1);
    const uint8_t packed = in_uint8_t[10 + (elem_idx >> data_per_byte_lg2)];
    const uint8_t level = (packed >> (slot * bitwidth)) & value_mask;
    return static_cast<float>(level*gap + min_f + out_float[elem_idx]);
  }
};


/*
 * Quantize `in_float` (in_float_size floats, accessed through `policy`) into
 * `out_uint8_t`, using `bitwidth` bits per value. bitwidth must be 1, 2, 4 or 8.
 *
 * Layout written to out_uint8_t:
 *   byte 0      : bitwidth
 *   byte 1      : tail = number of unused value slots in the final data byte
 *   bytes 2..5  : min_val (raw float bytes, host byte order)
 *   bytes 6..9  : max_val (raw float bytes, host byte order)
 *   bytes 10..  : packed levels; value k of a byte occupies bits
 *                 [bitwidth*k, bitwidth*(k+1))
 * Each value v becomes level (v - min_val) / gap, where
 * gap = (max_val - min_val) / (2^bitwidth - 1). The level is rounded to nearest,
 * or stochastically when `random` != 0 (a draw from naive_real_random is added
 * and the result truncated).
 *
 * out_uint8_t_size must be at least 10 + ceil(in_float_size * bitwidth / 8).
 * Returns 0 on success, or -1 for an invalid bitwidth or buffer size.
 */
template <typename policy_t>
int terngrad_body(
    float* in_float,
    int32_t in_float_size,
    uint8_t* out_uint8_t,
    int32_t out_uint8_t_size,
    uint8_t bitwidth,
    int32_t random,
    policy_t policy,
    void *stream)
{
  constexpr int32_t header_size = 10;

  // Validate before bitwidth is used as a table index or shift count:
  // bitwidth > 8 would read past the end of lg2[].
  static const uint8_t lg2[9] = {0,0,1,1,2,2,2,2,3};
  if (unlikely(bitwidth > 8 || (1<<lg2[bitwidth]) != bitwidth)){
    printf("Invalid value of bitwidth, chekc value: bitwidth=%d\n",bitwidth+0);
    return -1;
  }
  if (unlikely(in_float_size < 0)){
    printf("Invalid value of in_float_size: %d\n", in_float_size);
    return -1;
  }

  const uint8_t data_per_byte_lg2 = 3 - lg2[bitwidth];
  const uint8_t data_per_byte = 1 << data_per_byte_lg2;
  // Number of completely filled data bytes, produced on the device.
  const int32_t full_bytes = in_float_size >> data_per_byte_lg2;
  uint8_t tail = in_float_size % data_per_byte;
  tail = tail == 0 ? 0 : data_per_byte - tail;

  // Refuse buffers that the header, data and tail writes below would overrun.
  const int64_t required_size =
      static_cast<int64_t>(header_size) + full_bytes + (tail ? 1 : 0);
  if (unlikely(out_uint8_t_size < required_size)){
    printf("Output buffer too small: need %lld bytes, got %d\n",
           static_cast<long long>(required_size), out_uint8_t_size);
    return -1;
  }

  float min_val = 0.0f, max_val = 0.0f;
  if (in_float_size > 0){
    // On an empty range the iterators returned by minmax_element point at the
    // end of the range and must not be read, hence the guard.
    auto min_max = thrust::minmax_element(
      policy,
      in_float,
      in_float+in_float_size
    );
    get_policy<policy_t>::memcpyOut(&min_val,min_max.first,sizeof(min_val),stream);
    get_policy<policy_t>::memcpyOut(&max_val,min_max.second,sizeof(max_val),stream);
    get_policy<policy_t>::streamSynchronize(stream);
  }
  const float gap = (max_val - min_val) / ((1 << bitwidth) - 1.0f);
  // The 1e-8 keeps the reciprocal finite when all inputs are equal (gap == 0).
  const float gap_inverse = 1. / (gap + 1e-8);

  // Serialize the header with memcpy: header+2 / header+6 are not float-aligned,
  // and writing through a float* would also violate strict aliasing.
  uint8_t header[header_size];
  header[0] = bitwidth;
  header[1] = tail;
  std::memcpy(header + 2, &min_val, sizeof(min_val));
  std::memcpy(header + 6, &max_val, sizeof(max_val));
  // NOTE(review): header (and qval below) are host stack buffers handed to
  // memcpyIn on `stream`. This is only safe if memcpyIn has consumed them
  // before this function returns — TODO confirm for every policy.
  get_policy<policy_t>::memcpyIn(out_uint8_t, header, sizeof(header), stream);

  thrust::counting_iterator<int32_t> index_sequence_begin(0);
  if (random){
    thrust::transform(
      policy,
      index_sequence_begin,
      index_sequence_begin + full_bytes,
      out_uint8_t + header_size,
      compress_with_random(
        in_float,
        bitwidth,
        data_per_byte_lg2,
        min_val,
        gap_inverse,
        static_cast<unsigned long long>(
          std::chrono::high_resolution_clock::now()
          .time_since_epoch()
          .count()
        )
      )
    );
  }
  else{
    thrust::transform(
      policy,
      index_sequence_begin,
      index_sequence_begin + full_bytes,
      out_uint8_t + header_size,
      compress_without_random(
        in_float,
        bitwidth,
        data_per_byte_lg2,
        min_val,
        gap_inverse
      )
    );
  }

  uint8_t qval = 0;
  if (tail){
    // The final, partially filled byte is quantized on the host. It is always
    // rounded to nearest, even when `random` is set.
    const int remaining = data_per_byte - tail;
    float tail_data[8];
    get_policy<policy_t>::memcpyOut(
      tail_data,
      in_float + (in_float_size - remaining),
      sizeof(float) * remaining,
      stream
    );
    get_policy<policy_t>::streamSynchronize(stream);
    for (int k = 0; k < remaining; ++k){
      const uint8_t t = nearbyint((tail_data[k] - min_val)*gap_inverse);
      qval = qval | (t << (bitwidth*k));
    }
    // Element full_bytes*data_per_byte + k is decoded from byte
    // header_size + full_bytes, so that is where the tail byte belongs
    // (it equals out_uint8_t_size - 1 when the buffer is sized exactly).
    get_policy<policy_t>::memcpyIn(
      out_uint8_t + (header_size + full_bytes),
      &qval,
      sizeof(qval),
      stream
    );
  }

  return 0;
}

/*
 * Decompress a buffer produced by terngrad_body into `out_float`.
 *
 * Reads the 10-byte header [bitwidth, tail, min_val (4 bytes), max_val (4 bytes)],
 * then decodes ((in_uint8_t_size - 10) << data_per_byte_lg2) - tail values.
 * Value i is min_val + level_i * gap, where gap = (max_val - min_val) / (2^bitwidth - 1).
 * If is_add_to != 0, each decoded value is added to the existing out_float[i];
 * otherwise out_float[i] is overwritten.
 *
 * Returns 0 on success. Returns -1 if the header is invalid (bitwidth not in
 * {1,2,4,8}, or tail out of range), if the buffer is shorter than the header,
 * or if the decoded count exceeds out_float_size.
 */
template <typename policy_t>
int terngrad_r_body(
    float* out_float,
    int32_t out_float_size,
    uint8_t* in_uint8_t,
    int32_t in_uint8_t_size,
    int is_add_to,
    policy_t policy,
    void *stream)
{
  constexpr int32_t header_size = 10;
  if (unlikely(in_uint8_t_size < header_size)){
    printf("compressed buffer too small: %d bytes\n", in_uint8_t_size);
    return -1;
  }

  uint8_t header[header_size];
  get_policy<policy_t>::memcpyOutSync(header,in_uint8_t,sizeof(header));
  const uint8_t bitwidth = header[0];
  const uint8_t tail = header[1];
  // memcpy instead of a float* cast: header+2 / header+6 are unaligned.
  float min_val, max_val;
  std::memcpy(&min_val, header + 2, sizeof(min_val));
  std::memcpy(&max_val, header + 6, sizeof(max_val));

  // bitwidth comes from the buffer contents, so bound it before it is used as
  // a table index or shift count.
  static const uint8_t lg2[9] = {0,0,1,1,2,2,2,2,3};
  if (unlikely(bitwidth > 8 || (1<<lg2[bitwidth]) != bitwidth)){
    printf("bitwidth is invalid, check value: %d\n",bitwidth+0);
    return -1;
  }

  const uint8_t data_per_byte_lg2 = 3 - lg2[bitwidth];
  const float gap = (max_val - min_val) / ((1 << bitwidth) - 1.0f);

  // Computed in 64 bits so that large buffers cannot overflow the shift.
  const int64_t out_count =
      (static_cast<int64_t>(in_uint8_t_size - header_size) << data_per_byte_lg2) - tail;
  if (unlikely(tail >= (1 << data_per_byte_lg2) || out_count < 0 ||
               out_count > out_float_size)){
    printf("invalid compressed data: size=%d tail=%d out_float_size=%d\n",
           in_uint8_t_size, tail+0, out_float_size);
    return -1;
  }
  const int32_t n = static_cast<int32_t>(out_count);

  thrust::counting_iterator<int32_t> index_sequence_begin(0);
  if (is_add_to){
    thrust::transform(
      policy,
      index_sequence_begin,
      index_sequence_begin + n,
      out_float,
      decompress_add_to(
        in_uint8_t,
        out_float,
        bitwidth,
        data_per_byte_lg2,
        min_val,
        gap
      )
    );
  }
  else{
    thrust::transform(
      policy,
      index_sequence_begin,
      index_sequence_begin + n,
      out_float,
      decompress_write_to(
        in_uint8_t,
        bitwidth,
        data_per_byte_lg2,
        min_val,
        gap
      )
    );
  }

  return 0;
}

} // namespace gradient_compression_body
} // namespace zq_cpp_lib

#endif
